from __future__ import print_function
import math
import random
import sys
import time
from functools import partial
import tensorflow as tf
from optparse import OptionParser

from tensorflow_core.python.keras.callbacks import TensorBoard
from ioutils_direct import get_all_batch, INVALID_BOND, binary_fdim
import itertools

'''
Script for training the core finder model

Key changes from NIPS paper version:
- Addition of "rich" options for atom featurization with more informative descriptors
- Predicted reactivities are not 1D, but 5D and explicitly identify what the bond order of the product should be
'''

# Ranking cutoff: the top-NK atom-pair scores are considered per prediction.
NK = 20
# NOTE(review): NK0 is defined but never used anywhere in this file.
NK0 = 10

# Command-line options.  NOTE(review): the numeric defaults are ints/floats,
# but optparse returns *strings* for values passed on the command line; the
# int()/float() conversions below are commented out and most values are then
# overridden by the hard-coded settings that follow, so the CLI is largely
# inert in this version of the script.
parser = OptionParser()
parser.add_option("-t", "--train", dest="train_path")
parser.add_option("-m", "--save_dir", dest="save_path")
parser.add_option("-b", "--batch", dest="batch_size", default=20)
parser.add_option("-w", "--hidden", dest="hidden_size", default=100)
parser.add_option("-d", "--depth", dest="depth", default=1)
parser.add_option("-l", "--max_norm", dest="max_norm", default=5.0)
parser.add_option("-r", "--rich", dest="rich_feat", default=False)
opts, args = parser.parse_args()

'''
batch_size = int(opts.batch_size)
hidden_size = int(opts.hidden_size)
depth = int(opts.depth)
max_norm = float(opts.max_norm)
train_path = opts.train_path
'''
# Hard-coded overrides: these clobber whatever was passed via -t/-m.
opts.train_path = "data/train.txt.proc"
opts.save_path = "log/model_save"

# Hyperparameters referenced as module-level globals throughout the file.
batch_size = 1
hidden_size = 300
depth = 3
max_norm = 5.0
# train_path = '../data/train.txt.proc'
# save_path = '../mymodel'
# Force the plain (non-rich) atom featurization regardless of the -r flag.
opts.rich_feat = 0
if opts.rich_feat:
    from mol_graph_rich import atom_fdim as adim, bond_fdim as bdim, max_nb, smiles2graph_list_pack as _s2g
else:
    from mol_graph import atom_fdim as adim, bond_fdim as bdim, max_nb, smiles2graph_list_pack as _s2g

# Atom indices are taken from SMILES atom-map numbers (1-based -> 0-based).
smiles2graph_batch = partial(
    _s2g, idxfunc=lambda x: x.GetIntProp('molAtomMapNumber') - 1)


def gen():
    """Yield (i, [1] * i) for i = 1, 2, 3, ... without end.

    NOTE(review): appears to be a leftover demo generator; it is not used
    elsewhere in this file.
    """
    i = 1
    while True:
        yield i, [1] * i
        i += 1


def count(s):
    """Return the number of ':' characters in *s*.

    Used as a fast proxy for the number of mapped heavy atoms in a reaction
    SMILES (each atom-map annotation contains one colon).

    Args:
        s: any string (typically a reaction SMILES line).

    Returns:
        int: occurrences of ':' in *s*.
    """
    # Idiom: str.count is a single C-level pass instead of a manual loop.
    return s.count(':')


def gen_train(path="data/train.txt.proc", total_count=20):
    '''Process data from a text file; bin by number of heavy atoms
    since that will determine the input sizes in each batch.

    Yields `total_count` batches of (src_batch, edit_batch), where each
    batch is drawn round-robin from one randomly chosen size bucket.
    Relies on the module-level `batch_size` global.
    '''
    limits = [10, 20, 30, 40, 50, 60, 80, 100, 120, 150]
    bins = [[] for _ in limits]

    # Place each reaction into the first bucket whose cap fits its
    # (colon-counted) atom count; reactions above the last cap are dropped.
    with open(path, 'r') as fh:
        for raw in fh:
            rxn, edit_str = raw.strip("\r\n ").split()
            n_atoms = count(rxn)
            for idx, cap in enumerate(limits):
                if n_atoms <= cap:
                    bins[idx].append((rxn, edit_str))
                    break

    for bucket in bins:
        random.shuffle(bucket)

    cursors = [0] * len(bins)
    nonempty = [i for i, b in enumerate(bins) if b]
    for _ in range(total_count):
        src_batch, edit_batch = [], []
        bid = random.choice(nonempty)
        bucket = bins[bid]
        pos = cursors[bid]
        size = len(bucket)
        for _ in range(batch_size):
            # Keep only the reactant side of "reactants>agents>products".
            src_batch.append(bucket[pos][0].split('>')[0])
            edit_batch.append(bucket[pos][1])
            pos = (pos + 1) % size
        cursors[bid] = pos
        yield (src_batch, edit_batch)


def gen_eval(path="data/valid.txt.proc", total_count=20):
    '''Process data from a text file; bin by number of heavy atoms
    since that will determine the input sizes in each batch.

    Like gen_train, but additionally converts each batch into the graph
    tensors the model consumes, yielding the full 10-element tuple.
    Relies on the module-level `batch_size` global.
    '''
    limits = [10, 20, 30, 40, 50, 60, 80, 100, 120, 150]
    bins = [[] for _ in limits]

    # Bucket each reaction by its (colon-counted) atom count.
    with open(path, 'r') as fh:
        for raw in fh:
            rxn, edit_str = raw.strip("\r\n ").split()
            n_atoms = count(rxn)
            for idx, cap in enumerate(limits):
                if n_atoms <= cap:
                    bins[idx].append((rxn, edit_str))
                    break

    for bucket in bins:
        random.shuffle(bucket)

    cursors = [0] * len(bins)
    nonempty = [i for i, b in enumerate(bins) if b]
    for iter_count in range(total_count):
        src_batch, edit_batch = [], []
        bid = random.choice(nonempty)
        bucket = bins[bid]
        pos = cursors[bid]
        size = len(bucket)
        for _ in range(batch_size):
            # Keep only the reactant side of "reactants>agents>products".
            src_batch.append(bucket[pos][0].split('>')[0])
            edit_batch.append(bucket[pos][1])
            pos = (pos + 1) % size
        cursors[bid] = pos
        # Prepare batch for TF: graph tensors plus pairwise/edit labels.
        graphs = smiles2graph_batch(src_batch)
        _input_atom, _input_bond, _atom_graph, _bond_graph, _num_nbs, _node_mask = graphs
        cur_bin, cur_label, sp_label = get_all_batch(
            zip(src_batch, edit_batch))
        yield (iter_count, _input_atom, _input_bond, _atom_graph, _bond_graph,
               _num_nbs, _node_mask, cur_bin, cur_label, sp_label)


# Wrap the Python generators as tf.data.Dataset pipelines.
# NOTE(review): gen_train currently yields only (src_batch, edit_batch) --
# its TF-preparation step is commented out -- so this 10-element output
# signature does NOT match what that generator produces; restore the prep
# code (or use a Data_prepare_layer) before training.  gen_eval does match.
training_data = tf.data.Dataset.from_generator(
    gen_train,
    (tf.int32, tf.float32, tf.float32, tf.int32, tf.int32,
     tf.int32, tf.float32, tf.float32, tf.int32, tf.int32),
    ((tf.TensorShape([]),
      tf.TensorShape([batch_size, None, adim]),    # input_atom
      tf.TensorShape([batch_size, None, bdim]),    # input_bond
      tf.TensorShape([batch_size, None, max_nb, 2]),    # atom_graph
      tf.TensorShape([batch_size, None, max_nb, 2]),    # bond_graph
      tf.TensorShape([batch_size, None]),   # num_nbs
      tf.TensorShape([batch_size, None]),   # node_mask
      tf.TensorShape([batch_size, None, None, binary_fdim]),    # binary
      tf.TensorShape([batch_size, None]),   # label
      tf.TensorShape([batch_size, None])    # sp_label
      )))


eval_data = tf.data.Dataset.from_generator(
    gen_eval,
    (tf.int32, tf.float32, tf.float32, tf.int32, tf.int32,
     tf.int32, tf.float32, tf.float32, tf.int32, tf.int32),
    ((tf.TensorShape([]),
      tf.TensorShape([batch_size, None, adim]),    # input_atom
      tf.TensorShape([batch_size, None, bdim]),    # input_bond
      tf.TensorShape([batch_size, None, max_nb, 2]),    # atom_graph
      tf.TensorShape([batch_size, None, max_nb, 2]),    # bond_graph
      tf.TensorShape([batch_size, None]),   # num_nbs
      tf.TensorShape([batch_size, None]),   # node_mask
      tf.TensorShape([batch_size, None, None, binary_fdim]),    # binary
      tf.TensorShape([batch_size, None]),   # label
      tf.TensorShape([batch_size, None])    # sp_label
      )))
# training_data = training_data.batch(1)
# eval_data = eval_data.batch(1)


class Linear(tf.keras.layers.Layer):
    """Plain dense layer: y = x @ w + b.

    NOTE(review): this class is defined but not used elsewhere in this file
    (linearND is used instead) -- confirm before removing.
    """

    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        # Scale the init stddev down for wide inputs, capped at 0.1.
        stddev = min(1.0 / math.sqrt(input_shape[-1]), 0.1)
        # Consistency fix: create both weights via add_weight (the original
        # mixed a raw tf.Variable for w with add_weight for b).
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=tf.random_normal_initializer(stddev=stddev),
            dtype='float32',
            trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer='random_normal',
                                 trainable=True)

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w) + self.b


class linearND(tf.keras.layers.Layer):
    """Dense layer applied across the last axis of an N-dimensional tensor.

    Args:
        units: output feature width.
        init_bias: if None, no bias term is used; otherwise a bias vector is
            created, initialised to this constant value.
    """

    def __init__(self, units=300, init_bias=None):
        super(linearND, self).__init__()
        self.units = units
        self.init_bias = init_bias

    def build(self, input_shape):
        # Scale the init stddev down for wide inputs, capped at 0.1.
        stddev = min(1.0 / math.sqrt(int(input_shape[-1])), 0.1)
        self.w = self.add_weight(
            shape=(int(input_shape[-1]), self.units),
            initializer=tf.random_normal_initializer(stddev=stddev),
            dtype='float32',
            trainable=True)
        if self.init_bias is not None:
            # Bug fix: the bias was initialised with 'random_normal',
            # silently ignoring the requested init_bias constant.
            self.b = self.add_weight(
                shape=(self.units,),
                initializer=tf.constant_initializer(self.init_bias),
                trainable=True)

    def call(self, inputs, **kwargs):
        # Flatten all leading dims, matmul over the last one, then restore
        # the original leading shape with `units` as the new trailing dim.
        input_shape = inputs.get_shape().as_list()
        ndim = len(input_shape)
        X_shape = tf.gather(tf.shape(inputs), list(range(ndim - 1)))
        target_shape = tf.concat([X_shape, [self.units]], 0)
        flat = tf.reshape(inputs, [-1, input_shape[-1]])
        res = tf.matmul(flat, self.w)
        if self.init_bias is not None:
            res = res + self.b
        res = tf.reshape(res, target_shape)
        # Restore static shape info lost by the dynamic reshape.
        res.set_shape(input_shape[:-1] + [self.units])
        return res


class Data_prepare_layer(tf.keras.layers.Layer):
    """Layer wrapper around the SMILES-to-graph preprocessing step.

    Takes (src_batch, edit_batch) and returns the nine tensors the model
    consumes (graph tensors + pairwise binary features + labels).
    """

    def __init__(self):
        # Bug fix: the original `pass` skipped super().__init__(), which
        # breaks tf.keras.layers.Layer initialisation/tracking.
        super(Data_prepare_layer, self).__init__()

    def call(self, inputs, **kwargs):
        src_batch, edit_batch = inputs
        # Prepare batch for TF
        src_tuple = smiles2graph_batch(src_batch)
        _input_atom, _input_bond, _atom_graph, _bond_graph, _num_nbs, _node_mask = src_tuple
        cur_bin, cur_label, sp_label = get_all_batch(
            zip(src_batch, edit_batch))
        return (_input_atom, _input_bond, _atom_graph, _bond_graph, _num_nbs, _node_mask, cur_bin, cur_label, sp_label)


class Rcnn_Wl_Last(tf.keras.layers.Layer):
    '''This layer performs the WLN embedding (local, no attention mechanism).

    Runs `depth` rounds of neighbour message passing over the molecular
    graph and returns per-atom fingerprints plus their sum (molecular FP).
    '''

    def __init__(self, hidden_size, depth=3):
        super(Rcnn_Wl_Last, self).__init__()
        # Kept so the mask tiling in call() matches the embedding width
        # (the original hard-coded 300 there).
        self.hidden_size = hidden_size
        self.linearND_atom_embedding = linearND(
            units=hidden_size, init_bias=None)  # hidden_size is default to 300
        self.linearND_nei_atom = linearND(units=hidden_size, init_bias=None)
        self.linearND_nei_bond = linearND(units=hidden_size, init_bias=None)
        self.linearND_self_atom = linearND(units=hidden_size, init_bias=None)
        self.linearND_label_U2 = linearND(units=hidden_size, init_bias=0)
        self.linearND_label_U1 = linearND(units=hidden_size, init_bias=0)
        self.depth = depth

    def call(self, inputs, **kwargs):
        """inputs: (input_atom, input_bond, atom_graph, bond_graph,
        num_nbs, node_mask).  Returns (atom_fps, molecular_fp)."""
        input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = inputs
        batch_size = input_atom.get_shape().as_list()[0]
        atom_features = tf.nn.relu(self.linearND_atom_embedding(input_atom))
        node_mask = tf.reshape(node_mask, (tf.shape(node_mask)[0], -1, 1))
        # Bug fix: tile by hidden_size instead of the hard-coded 300 so the
        # layer works for any embedding width.
        node_mask = tf.tile(node_mask, [1, 1, self.hidden_size])
        layers = []

        for i in range(self.depth):
            # Gather each atom's neighbour atom/bond features.
            fatom_nei = tf.gather_nd(atom_features, atom_graph)
            fbond_nei = tf.gather_nd(input_bond, bond_graph)
            h_nei_atom = self.linearND_nei_atom(fatom_nei)
            h_nei_bond = self.linearND_nei_bond(fbond_nei)
            h_nei = h_nei_atom * h_nei_bond
            # Mask out padded neighbour slots (each atom has <= max_nb).
            mask_nei = tf.reshape(tf.sequence_mask(tf.reshape(num_nbs, [-1]),
                                                   max_nb, dtype=tf.float32),
                                  [batch_size, -1, max_nb, 1])
            f_nei = tf.reduce_sum(h_nei * mask_nei, -2)
            f_self = self.linearND_self_atom(atom_features)
            layers.append(f_nei * f_self * node_mask)   # output
            # Update atom labels from neighbour atom+bond features.
            l_nei = tf.concat([fatom_nei, fbond_nei], 3)
            nei_label = tf.nn.relu(self.linearND_label_U2(l_nei))
            nei_label = tf.reduce_sum(nei_label * mask_nei, -2)
            new_label = tf.concat([atom_features, nei_label], 2)
            new_label = self.linearND_label_U1(new_label)
            atom_features = tf.nn.relu(new_label)   # update atom features
        # Atom FPs are the final output after "depth" convolutions.
        kernels = layers[-1]
        fp = tf.reduce_sum(kernels, 1)  # molecular FP is sum over atom FPs
        return kernels, fp


class chemFormularGCN(tf.keras.Model):
    """Reaction-core finder: WLN graph encoder plus a global attention
    mechanism that scores every atom pair for each of 5 candidate bond
    orders.

    NOTE(review): `predict` intentionally shadows tf.keras.Model.predict
    with a custom signature; callers in this script use it directly.
    """

    def __init__(self, device='cpu:0',
                 checkpoint_directory='log/model_save/checkpoint'):
        super(chemFormularGCN, self).__init__()
        self.device = device
        self.checkpoint_directory = checkpoint_directory
        self.batch_size = 1
        # Top-k cutoffs for the accuracy@k metrics.
        self.NK1 = 5
        self.NK2 = 10
        self.NK3 = 20
        # Running hit counters for accuracy@k, reset every 50 steps.
        self.sum_acc = 0
        self.sum_acc1 = 0
        self.sum_acc2 = 0
        self.sum_acc3 = 0
        # Bug fix: the original wrote `step = 0` (a discarded local) while
        # accuracy()/eval() read an undefined global `step`.  Track the
        # batch counter on the instance; loss_fn advances it.
        self.step = 0
        # 0 for init, 1 for training, 2 for eval train_data, 3 for eval eval_data
        self.state = 0
        self.history = {'train_loss': [], 'train_acc': [], 'train_acc1': [], 'train_acc2': [], 'train_acc3': [],
                        'eval_loss': [], 'eval_acc': [], 'eval_acc1': [], 'eval_acc2': [], 'eval_acc3': []}
        self.data_prepare_layer = Data_prepare_layer()
        self.rcnn_wl_last = Rcnn_Wl_Last(hidden_size=300, depth=3)
        self.linearND_att_atom_feature = linearND(
            units=hidden_size, init_bias=None)
        self.linearND_att_bin_feature = linearND(
            units=hidden_size, init_bias=0)
        self.linearND_att_scores = linearND(units=1, init_bias=None)
        self.linearND_atom_feature = linearND(
            units=hidden_size, init_bias=None)
        self.linearND_bin_feature = linearND(units=hidden_size, init_bias=None)
        self.linearND_ctx_feature = linearND(units=hidden_size, init_bias=None)
        # 5 outputs: one score per candidate bond order for each atom pair.
        self.linearND_scores = linearND(units=5, init_bias=None)

        # Mean-trackers for train/eval loss and the accuracy@k metrics.
        self.train_loss = tf.keras.metrics.Mean('train_loss')
        self.eval_loss = tf.keras.metrics.Mean('eval_loss')
        self.acc = tf.keras.metrics.Mean('acc')
        self.acc1 = tf.keras.metrics.Mean('acc1')
        self.acc2 = tf.keras.metrics.Mean('acc2')
        self.acc3 = tf.keras.metrics.Mean('acc3')

    def predict(self, input_data, bmask):
        """Score every atom pair for each candidate bond order.

        Args:
            input_data: [graph_inputs, binary] where graph_inputs is the
                6-tuple of graph tensors and binary the pairwise features.
            bmask: additive mask (10000 at INVALID_BOND positions) that is
                subtracted from scores so invalid pairs never reach top-k.

        Returns:
            (topk indices, topk scores, attention scores, flat scores).

        NOTE(review): relies on the module-level `batch_size`,
        `hidden_size` and `NK` globals, as the original did.
        """
        graph_inputs, binary = input_data
        input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask = graph_inputs
        node_mask = tf.expand_dims(node_mask, -1)

        # Perform the WLN embedding.
        atom_hiddens, _ = self.rcnn_wl_last(graph_inputs)

        # Local atom-pair features: sum of the two atoms' embeddings.
        atom_hiddens1 = tf.reshape(
            atom_hiddens, [batch_size, 1, -1, hidden_size])
        atom_hiddens2 = tf.reshape(
            atom_hiddens, [batch_size, -1, 1, hidden_size])
        atom_pair = atom_hiddens1 + atom_hiddens2

        # Attention score for each pair of atoms.
        att_hidden = tf.nn.relu(self.linearND_att_atom_feature(
            atom_pair) + self.linearND_att_bin_feature(binary))
        att_score = self.linearND_att_scores(att_hidden)
        att_score = tf.nn.sigmoid(att_score)

        # Context features: attention-weighted sum of atom embeddings.
        att_context = att_score * atom_hiddens1
        att_context = tf.reduce_sum(att_context, 2)

        # Global atom-pair features from the context vectors.
        att_context1 = tf.reshape(
            att_context, [batch_size, 1, -1, hidden_size])
        att_context2 = tf.reshape(
            att_context, [batch_size, -1, 1, hidden_size])
        att_pair = att_context1 + att_context2

        # Likelihood of each pair forming each candidate bond order.
        pair_hidden = self.linearND_atom_feature(atom_pair) + self.linearND_bin_feature(binary) + \
            self.linearND_ctx_feature(att_pair)

        pair_hidden = tf.nn.relu(pair_hidden)
        pair_hidden = tf.reshape(pair_hidden, [batch_size, -1, hidden_size])
        score = self.linearND_scores(pair_hidden)
        score = tf.reshape(score, [batch_size, -1])

        topk_scores, topk = tf.nn.top_k(score - bmask, k=NK)
        flat_score = tf.reshape(score, [-1])
        return topk, topk_scores, att_score, flat_score

    def accuracy(self, sp_label, cur_topk):
        """Update running accuracy@k counters from one batch's top-k.

        A sample counts as a hit at cutoff k when ALL of its true labels
        appear within the first k predictions.
        """
        for i in range(self.batch_size):
            n_true = int(tf.shape(sp_label[i])[0])
            # (counter attribute, cutoff) pairs: exact, @5, @10, @20.
            for attr, cutoff in (('sum_acc', n_true), ('sum_acc1', self.NK1),
                                 ('sum_acc2', self.NK2), ('sum_acc3', self.NK3)):
                hits = sum(1 for j in range(cutoff)
                           if cur_topk[i, j] in sp_label[i])
                if hits == len(sp_label[i]):
                    setattr(self, attr, getattr(self, attr) + 1)

        # During training: print and reset the counters every 50 steps.
        if self.state == 1:
            delta = 50
            if self.step % 50 == 0:
                print("Acc: %.4f, Acc@5: %.4f, Acc@10: %.4f, Acc@20: %.4f"
                      % (self.sum_acc / (delta * self.batch_size),
                         self.sum_acc1 / (delta * self.batch_size),
                         self.sum_acc2 / (delta * self.batch_size),
                         self.sum_acc3 / (delta * self.batch_size)))
                sys.stdout.flush()
                self.sum_acc, self.sum_acc1, self.sum_acc2, self.sum_acc3 = 0.0, 0.0, 0.0, 0.0

        # During evaluation: fold windowed accuracies into the Mean metrics.
        if self.state == 2 or self.state == 3:
            delta = 50
            if self.step % 50 == 0:
                self.acc(self.sum_acc / (delta * self.batch_size))
                self.acc1(self.sum_acc1 / (delta * self.batch_size))
                self.acc2(self.sum_acc2 / (delta * self.batch_size))
                self.acc3(self.sum_acc3 / (delta * self.batch_size))
                self.sum_acc, self.sum_acc1, self.sum_acc2, self.sum_acc3 = 0.0, 0.0, 0.0, 0.0

    def eval(self, dataset):
        """Run one pass over `dataset`, accumulating loss/accuracy, then
        append the epoch summary to self.history (state 2 = train data,
        state 3 = eval data) and reset the metrics."""
        for _, input_atom, input_bond, atom_graph, bond_graph, num_nbs, node_mask, binary, label, sp_label in dataset:
            # self.step is advanced inside loss_fn (one tick per batch).
            graph_inputs = [input_atom, input_bond,
                            atom_graph, bond_graph, num_nbs, node_mask]
            input_data = [graph_inputs, binary]
            target = [label, sp_label]
            loss = self.loss_fn(input_data, target)
            if self.state == 2:
                self.train_loss(loss)
            if self.state == 3:
                self.eval_loss(loss)

        if self.state == 2:
            self.history['train_loss'].append(self.train_loss.result().numpy())
            self.history['train_acc'].append(self.acc.result().numpy())
            self.history['train_acc1'].append(self.acc1.result().numpy())
            self.history['train_acc2'].append(self.acc2.result().numpy())
            self.history['train_acc3'].append(self.acc3.result().numpy())
            # Reset metrics for train_loss
            self.train_loss.reset_states()
        if self.state == 3:
            self.history['eval_loss'].append(self.eval_loss.result().numpy())
            self.history['eval_acc'].append(self.acc.result().numpy())
            self.history['eval_acc1'].append(self.acc1.result().numpy())
            self.history['eval_acc2'].append(self.acc2.result().numpy())
            self.history['eval_acc3'].append(self.acc3.result().numpy())
            # Reset metrics for eval_loss
            self.eval_loss.reset_states()
        # Reset the shared accuracy metrics either way.
        self.acc.reset_states()
        self.acc1.reset_states()
        self.acc2.reset_states()
        self.acc3.reset_states()

    def loss_fn(self, input_data, target):
        """Masked multi-label sigmoid cross-entropy over all atom pairs.

        Also advances self.step (one tick per batch) and updates the
        accuracy counters via self.accuracy().
        """
        # Bug fix: replaces the undefined global `step` the original
        # accuracy() depended on.
        self.step += 1
        label, sp_label = target
        flat_label = tf.reshape(label, [-1])
        # INVALID_BOND positions contribute nothing to the loss.
        bond_mask = tf.cast(tf.not_equal(flat_label, INVALID_BOND), tf.float32)
        flat_label = tf.maximum(0, flat_label)
        # Large additive penalty keeps invalid pairs out of top-k.
        bmask = tf.cast(tf.equal(label, INVALID_BOND), tf.float32) * 10000
        topk, topk_scores, att_score, flat_score = self.predict(
            input_data, bmask)
        self.accuracy(sp_label, topk)
        # Sigmoid (multi-label) cross-entropy -- the original comment said
        # "categorical", but the op is sigmoid_cross_entropy_with_logits.
        loss = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=flat_score, labels=tf.cast(flat_label, tf.float32))
        loss = tf.reduce_sum(loss * bond_mask)
        loss = loss / self.batch_size
        return loss

    def grads_fn(self, input_data, target):
        """Dynamically computes the clipped gradients of the loss value
        with respect to this model's trainable variables.

        Bug fix: the original differentiated w.r.t. the module-level
        `model` global instead of `self`.
        """
        with tf.GradientTape() as tape:
            loss = self.loss_fn(input_data, target)
        grads = tape.gradient(loss, self.trainable_variables)
        new_grads, _ = tf.clip_by_global_norm(grads, max_norm)
        return new_grads

    def restore_model(self):
        """Restore the model weights from the latest checkpoint.

        Bug fix: `tf.Saver` does not exist in TF2; use object-based
        checkpointing instead.
        """
        with tf.device(self.device):
            ckpt = tf.train.Checkpoint(model=self)
            ckpt.restore(tf.train.latest_checkpoint(self.checkpoint_directory))

    def save_model(self, global_step=0):
        """Export the model as a SavedModel.

        NOTE(review): the step is appended to the directory name without a
        separator (e.g. 'checkpoint150'); confirm the intended layout.
        """
        tf.saved_model.save(
            self, self.checkpoint_directory + str(global_step))


# Instantiate the model and the optimizer used by the training loop below.
model = chemFormularGCN()
optimizer = tf.optimizers.Adam(learning_rate=0.001)
# NOTE(review): the run name looks copy-pasted from a cats-vs-dogs project;
# confirm it is intentional before relying on the log directory layout.
model_name = "kaggle_cat_dog-cnn-64x2-{}".format(int(time.time()))

# NOTE(review): this Keras TensorBoard callback is created but never passed
# to any fit() call in this file, so as written it has no effect.
tensorboard = TensorBoard(
    log_dir='log/{}'.format(model_name))


# ---------------------------------------------------------------------------
# Manual training loop.
#
# Bug fixes relative to the original tail of this file (which did not parse:
# the body of a `fit` method had been pasted at module level with broken
# indentation):
#   * wrapped in a proper function with `model` passed explicitly;
#   * `num_epochs = 100,` created a 1-tuple -- now a plain int default;
#   * `step` and `count` are initialised before use;
#   * the local `history` dict was never written to -- use `model.history`,
#     which model.eval() actually fills;
#   * gradients are applied to `trainable_variables` (what grads_fn
#     differentiates against), not `variables`;
#   * TensorBoard logging writes the latest epoch's value ([-1]) instead of
#     re-logging the first epoch's value ([0]) forever.
# ---------------------------------------------------------------------------
batch_size = 1


def fit(model, training_data, eval_data, optimizer, num_epochs=100,
        early_stopping_rounds=10, verbose=1, train_from_restore=False):
    """ Function to train the model, using the selected optimizer and
        for the desired number of epochs. You can either train from scratch
        or load the latest model trained. Early stopping is used in order to
        mitigate the risk of overfitting the network.

        Args:
            model: the chemFormularGCN instance to train.
            training_data: the data you would like to train the model on.
                            Must be in the tf.data.Dataset format.
            eval_data: the data you would like to evaluate the model on.
                        Must be in the tf.data.Dataset format.
            optimizer: the optimizer used during training.
            num_epochs: the maximum number of iterations you would like to
                        train the model.
            early_stopping_rounds: stop training if the loss on the eval
                                   dataset does not decrease after n epochs.
            verbose: int. Specify how often to print the loss value of the network.
            train_from_restore: boolean. Whether to initialize variables from
                                the last trained model or initialize them
                                randomly.
    """
    if train_from_restore:
        model.restore_model()

    # Lowest loss observed on the eval dataset so far.
    best_loss = 999
    # Epochs remaining before early stopping triggers.
    count = early_stopping_rounds
    # Global batch counter, used only for printing and checkpoint naming.
    step = 0
    history = model.history
    writer = tf.summary.create_file_writer("log/")

    def record_tensorboard(history, step):
        # Log the most recent epoch's metrics to TensorBoard.
        with writer.as_default():
            for key in ('train_loss', 'train_acc', 'train_acc1',
                        'train_acc2', 'train_acc3', 'eval_loss',
                        'eval_acc', 'eval_acc1', 'eval_acc2', 'eval_acc3'):
                tf.summary.scalar(key, history[key][-1], step=step)

    # Begin training.
    with tf.device(model.device):
        for i in range(num_epochs):
            # Training pass with gradient descent.
            model.state = 1
            for (_, input_atom, input_bond, atom_graph, bond_graph, num_nbs,
                 node_mask, binary, label, sp_label) in training_data:
                step += 1
                print(step)
                graph_inputs = [input_atom, input_bond,
                                atom_graph, bond_graph, num_nbs, node_mask]
                input_data = [graph_inputs, binary]
                target = [label, sp_label]
                grads = model.grads_fn(input_data, target)
                optimizer.apply_gradients(
                    zip(grads, model.trainable_variables))

            # Compute the loss on the training data after one epoch.
            model.state = 2
            model.eval(dataset=training_data)
            # Compute the loss on the eval data after one epoch.
            model.state = 3
            model.eval(dataset=eval_data)

            # Print train and eval losses, log to TensorBoard, checkpoint.
            if i == 0 or (i + 1) % verbose == 0:
                print('Train loss at epoch %d: ' %
                      (i + 1), history['train_loss'][-1])
                print('Eval loss at epoch %d: ' %
                      (i + 1), history['eval_loss'][-1])
                record_tensorboard(history, i)
                writer.flush()
                model.save_model(step)

            # Early stopping: quit once eval loss has failed to improve for
            # `early_stopping_rounds` consecutive epochs.
            if history['eval_loss'][-1] < best_loss:
                best_loss = history['eval_loss'][-1]
                count = early_stopping_rounds
            else:
                count -= 1
            if count == 0:
                break


if __name__ == "__main__":
    fit(model, training_data, eval_data, optimizer,
        num_epochs=100, early_stopping_rounds=10, verbose=1,
        train_from_restore=False)
